import argparse
import torch
import numpy as np
import gym
import os
import random
from tqdm import tqdm
from torch.optim import Adam
import matplotlib.pyplot as plt
from noda.noda import NODA, NODANoPartial
from sac.sac import ReplayBuffer
import pdb


def compute_loss_model(model, data):
    """Return the un-reduced model loss for one batch of transitions.

    The loss is ``0.5 * (next-obs error + reconstruction error) + 0.5 * reward
    error``. Both observation errors are squared errors averaged over dim 1;
    the reward error is a plain element-wise squared error.

    Args:
        model: Callable invoked as ``model(obs, act)`` that returns the tuple
            ``(next_obs_pred, rew_pred, obs_recon)``.
        data: Mapping with the keys ``'obs'``, ``'act'``, ``'rew'``, ``'obs2'``
            and ``'done'``. All five are looked up, although ``'done'`` does
            not enter the loss.

    Returns:
        The loss tensor before any batch reduction; callers reduce it
        themselves (e.g. with ``.mean()``).
    """
    obs, act, rew, next_obs, _done = (
        data[key] for key in ('obs', 'act', 'rew', 'obs2', 'done')
    )
    next_obs_pred, rew_pred, obs_recon = model(obs, act)

    def _mean_sq_err(pred, target):
        # Squared error averaged across dim 1 (assumed to be the feature axis).
        return (pred - target).pow(2).mean(dim=1)

    obs_term = _mean_sq_err(next_obs_pred, next_obs) + _mean_sq_err(obs_recon, obs)
    rew_term = (rew_pred - rew).pow(2)
    return 0.5 * obs_term + 0.5 * rew_term


def get_buffer(env, steps, device, max_ep_len=1000):
    """Collect `steps` transitions with randomly sampled actions.

    Actions come from ``env.action_space.sample()``. The environment is reset
    whenever it reports termination or an episode reaches `max_ep_len` steps.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(a)`` returns a 4-tuple ``(obs, reward, done, info)``.
        steps: Number of transitions to collect; also the buffer capacity.
        device: Forwarded unchanged to ``ReplayBuffer``.
        max_ep_len: Episode length at which the episode is cut off.

    Returns:
        The ``ReplayBuffer`` after `steps` calls to ``store``.
    """
    buffer = ReplayBuffer(
        obs_dim=env.observation_space.shape,
        act_dim=env.action_space.shape[0],
        size=steps,
        device=device,
    )
    obs = env.reset()
    ep_len = 0
    with tqdm(total=steps, desc='Generating data') as progress:
        for _ in range(steps):
            action = env.action_space.sample()
            next_obs, reward, done, _ = env.step(action)
            ep_len += 1
            hit_time_limit = ep_len == max_ep_len
            # A cutoff caused only by the length cap is stored as done=False,
            # so it is not recorded as a terminal transition.
            stored_done = False if hit_time_limit else done
            buffer.store(obs, action, reward, next_obs, stored_done)
            if hit_time_limit or done:
                obs = env.reset()
                ep_len = 0
            else:
                obs = next_obs
            progress.update()
    return buffer


def prior_knowledge_plot(args, target_path):
    """Plot test loss (mean with a +/- one std band) against latent dimension.

    Reads the arrays ``test_loss_mean`` and ``test_loss_std`` from the ``.npz``
    file at `target_path` and writes a PDF into ``args.save_dir``.

    Args:
        args: Namespace providing ``lats_noda``, ``save_dir``, ``env``,
            ``hid_noda_ae``, ``hid_noda_ode``, ``env_steps`` and
            ``model_steps``.
        target_path: Path of the ``.npz`` results file to visualise.
    """
    # np.load returns an NpzFile that holds the archive open, so read the
    # arrays inside a context manager. The file holds plain numeric arrays
    # written by np.savez, so pickle support is not needed.
    with np.load(target_path) as results:
        test_loss_mean = results['test_loss_mean']
        test_loss_std = results['test_loss_std']

    # Values given on the command line may arrive as strings; matplotlib would
    # then treat the x axis as categorical instead of numeric.
    lats_noda = [int(lat) for lat in args.lats_noda]

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
    ax.plot(lats_noda, test_loss_mean, linewidth=3.5)
    ax.fill_between(lats_noda, test_loss_mean - test_loss_std,
                    test_loss_mean + test_loss_std, alpha=0.3)
    ax.set_xlabel('Latent Dimension')
    ax.set_ylabel('Testing Loss')
    ax.grid(True)
    # Title follows the selected environment instead of a fixed 'Ant-v3'.
    ax.set_title(args.env)

    file_name = 'testing-prior-knowledge_{}_{}_{}_{}_{}.pdf'.format(
        args.env, args.hid_noda_ae, args.hid_noda_ode,
        args.env_steps, args.model_steps)
    fig.savefig(os.path.join(args.save_dir, file_name))
    plt.close(fig)


def _build_arg_parser():
    """Create the command-line parser for the prior-knowledge experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Ant-v3')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='noda')
    # type=int is required: without it, values given on the command line are
    # strings and would be passed to the models as latent_dim unchanged.
    parser.add_argument('--lats-noda', nargs='+', type=int,
                        default=[5, 10, 20, 40, 80, 160, 320, 640])
    parser.add_argument('--hid-noda-ae', type=int, default=256)
    parser.add_argument('--hid-noda-ode', type=int, default=64)
    parser.add_argument('--env-steps', type=int, default=20000)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--model-steps', type=int, default=3000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--retrain', action='store_true', default=False)
    parser.add_argument('--save-dir', default='results/prior_knowledge', type=str)
    return parser


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only inherited by child processes; the running interpreter's hash seed
    # was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def prior_knowledge_main():
    """Train one NODANoPartial model per latent size and plot their test loss.

    If a results file for the current configuration already exists and
    ``--retrain`` is not given, training is skipped and only the plot is
    regenerated. Otherwise a train buffer and a test buffer are collected with
    random actions, each model is trained for ``--model-steps`` mini-batch
    updates, the per-sample test loss mean and std are saved to an ``.npz``
    file, and the plot is produced from that file.
    """
    args = _build_arg_parser().parse_args()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    _seed_everything(args.seed)
    # NOTE(review): neither the environment nor its action space is seeded in
    # this script, so the collected data may differ between runs - confirm
    # whether that is intended.

    os.makedirs(args.save_dir, exist_ok=True)
    target_path = os.path.join(
        args.save_dir,
        'prior_knowledge_results_{}_{}_{}_{}_{}.npz'.format(
            args.env, args.hid_noda_ae, args.hid_noda_ode,
            args.env_steps, args.model_steps))

    if os.path.isfile(target_path) and not args.retrain:
        prior_knowledge_plot(args, target_path)
        return

    env = gym.make(args.env)
    # All models are built before any data is generated, so parameter
    # initialisation consumes the seeded RNG stream first.
    models = [NODANoPartial(env.observation_space, env.action_space,
                            latent_dim=lat_noda, hidden_dim_ode=args.hid_noda_ode,
                            hidden_dim_ae=args.hid_noda_ae).to(args.device)
              for lat_noda in args.lats_noda]
    train_buffer = get_buffer(env, args.env_steps, args.device)
    test_buffer = get_buffer(env, args.env_steps, args.device)

    test_loss_mean = []
    test_loss_std = []
    for model in models:
        model_optimizer = Adam(model.parameters(), lr=args.lr)
        with tqdm(total=args.model_steps) as t:
            for _ in range(args.model_steps):
                # Random mini-batch of distinct indices from the train buffer.
                ixs = torch.randperm(args.env_steps)[:args.batch_size]
                loss = compute_loss_model(model, train_buffer.get_batch(ixs)).mean()
                model_optimizer.zero_grad()
                loss.backward()
                model_optimizer.step()
                t.set_postfix(train_loss='{:.9f}'.format(loss.item()))
                t.update()
        with torch.no_grad():
            # Evaluate on the whole test buffer in a single (shuffled) batch.
            test_ixs = torch.randperm(args.env_steps)
            test_loss = compute_loss_model(model, test_buffer.get_batch(test_ixs)).cpu().numpy()
        loss_mean, loss_std = np.mean(test_loss), np.std(test_loss)
        test_loss_mean.append(loss_mean)
        test_loss_std.append(loss_std)
        print(loss_mean, loss_std)

    np.savez(target_path,
             test_loss_mean=np.array(test_loss_mean),
             test_loss_std=np.array(test_loss_std))
    prior_knowledge_plot(args, target_path)


if __name__ == '__main__':
    # Global matplotlib styling for the generated figure: Times New Roman as
    # the sans-serif font, automatic tight layout, and a 23 pt font size.
    plt.rcParams.update({
        'font.sans-serif': ['Times New Roman'],
        'figure.autolayout': True,
        'font.size': 23,
    })
    prior_knowledge_main()
